# Copyright 2021-2024 The PySCF Developers. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Author: Chenghan Li <lch004218@gmail.com>

import numpy as np
import pyscf
from pyscf import lib
from pyscf import gto
from pyscf import df
from pyscf import grad
from pyscf.lib import logger

import cupy as cp
from gpu4pyscf import scf
from gpu4pyscf.qmmm.pbc import mm_mole
from gpu4pyscf.lib import cupy_helper
from gpu4pyscf.qmmm.pbc.tools import get_multipole_tensors_pp, get_multipole_tensors_pg
from gpu4pyscf.gto.int3c1e import int1e_grids
from gpu4pyscf.gto.int3c1e_ip import int1e_grids_ip1, int1e_grids_ip2

# Module-level shorthand for the contraction helper provided by
# gpu4pyscf.lib.cupy_helper; it is called throughout this file with
# einsum-style subscript strings.
contract = cupy_helper.contract

from cupyx.scipy.special import erfc, erf

def add_mm_charges(scf_method, atoms_or_coords, a, charges, radii=None,
        rcut_ewald=None, rcut_hcore=None, unit=None):
    '''Embed the one-electron (non-relativistic) potential generated by MM
    point charges into the QM Hamiltonian.

    The total energy includes the regular QM energy, the interaction between
    the nuclei in the QM region and the MM charges, and the static Coulomb
    interaction between the electron density and the MM charges. The
    electrostatic interactions between the reference cell and its periodic
    images are also computed at point-charge level. It does not include the
    static Coulomb interactions among the MM point charges, the MM energy, the
    vdw interaction or other bonding/non-bonding effects between the QM region
    and the MM particles.

    Args:
        scf_method : a HF or DFT object

        atoms_or_coords : 2D array, shape (N,3)
            MM particle coordinates
        a : 2D array, shape (3,3)
            Lattice vectors
        charges : 1D array
            MM particle charges

    Kwargs:
        radii : 1D array
            The Gaussian charge distribution radii of MM atoms.
        rcut_ewald : float
            The real-space Ewald cutoff.
        rcut_hcore : float
            The cutoff for exact MM potential when computing hcore.
        unit : str
            Bohr, AU, Ang (case insensitive). Default is the same as mol.unit

    Returns:
        Same method object as the input scf_method with modified 1e Hamiltonian

    Examples:

    >>> mol = gto.M(atom='H 0 0 0; F 0 0 1', basis='ccpvdz', verbose=0)
    >>> mf = add_mm_charges(dft.RKS(mol), [(0.5,0.6,0.8)], np.eye(3)*10, [-0.3])
    >>> mf.kernel()
    '''
    qm_mol = scf_method.mol
    # Fall back to the unit of the QM molecule when none is given.
    length_unit = qm_mol.unit if unit is None else unit
    mm_cell = mm_mole.create_mm_mol(
        atoms_or_coords, a, charges,
        radii=radii, rcut_ewald=rcut_ewald, rcut_hcore=rcut_hcore,
        unit=length_unit)
    return qmmm_for_scf(scf_method, mm_cell)

def qmmm_for_scf(method, mm_mol):
    '''Attach the MM environment to an SCF (HF or DFT) method and return the
    corresponding QM/MM method for the QM system.

    Args:
        method : SCF object to be wrapped
        mm_mol : MM Mole object

    Returns:
        ``method`` itself (with ``mm_mol`` replaced and all cached quantities
        cleared) when it is already a QMMM object; otherwise a new object
        whose class combines QMMMSCF with the class of ``method``.

    Raises:
        NotImplementedError : if ``method`` is not an SCF object
    '''
    if not isinstance(method, scf.hf.SCF):
        # post-HF methods
        raise NotImplementedError()

    if isinstance(method, QMMM):
        # Already wrapped: swap the MM environment in place and drop every
        # cached quantity instead of initializing QMMM twice.
        method.mm_mol = mm_mol
        for cached in ('s1r', 's1rr', 'mm_ewald_pot', 'qm_ewald_hess', 'e_nuc'):
            setattr(method, cached, None)
        return method

    wrapped = QMMMSCF(method, mm_mol)
    return lib.set_class(wrapped, (QMMMSCF, method.__class__))

class QMMM:
    '''Marker base class for QM/MM-wrapped methods.

    It defines no behaviour of its own; ``qmmm_for_scf`` and
    ``qmmm_grad_for_scf`` use ``isinstance(obj, QMMM)`` to recognise objects
    that already carry an MM environment.
    '''
    # Name tag of the mixin; presumably consumed by pyscf's dynamic class
    # helpers (lib.set_class / lib.make_class) -- TODO confirm.
    __name_mixin__ = 'QMMM'

# Private alias of the same class.
_QMMM = QMMM

class QMMMSCF(QMMM):
    '''QM/MM mixin for GPU SCF objects embedded in a periodic MM environment.

    Instances are created by ``qmmm_for_scf``, which combines this class with
    the class of the wrapped SCF object.  MM charges within ``rcut_hcore`` of
    the QM centroid enter the core Hamiltonian directly (``get_hcore``).  In
    addition, per-atom QM charges, dipoles and quadrupoles are coupled to the
    potentials returned by ``mm_mol.get_ewald_pot`` (``get_veff``,
    ``energy_ewald``).

    Attributes:
        mm_mol : MM cell object supplying coordinates, charges, lattice
            vectors and the Ewald routines used here.
        s1r, s1rr : per-atom integral caches filled by ``get_s1r`` and
            ``get_s1rr``.
        mm_ewald_pot : cache of ``get_mm_ewald_pot``, filled by ``get_veff``.
        qm_ewald_hess : cache filled by ``get_qm_ewald_pot``.
        e_nuc : cache of ``energy_nuc``.
    '''
    _keys = {'mm_mol', 's1r', 's1rr', 'mm_ewald_pot', 'qm_ewald_hess', 'e_nuc'}

    to_cpu     = NotImplemented
    as_scanner = NotImplemented

    def __init__(self, method, mm_mol):
        '''Copy the state of ``method`` and attach ``mm_mol`` with empty caches.'''
        self.__dict__.update(method.__dict__)
        self.mm_mol = mm_mol
        self.s1r = None
        self.s1rr = None
        self.mm_ewald_pot = None
        self.qm_ewald_hess = None
        self.e_nuc = None

    def dump_flags(self, verbose=None):
        '''Log the parent flags, the MM lattice vectors and, at DEBUG2
        verbosity, every MM charge with its location.  Returns ``self``.'''
        super().dump_flags(verbose)
        logger.info(self, '** Add background charges for %s **',
                    self.__class__.__name__)
        _a = self.mm_mol.lattice_vectors()
        logger.info(self, 'lattice vectors  a1 [%.9f, %.9f, %.9f]', *_a[0])
        logger.info(self, '                 a2 [%.9f, %.9f, %.9f]', *_a[1])
        logger.info(self, '                 a3 [%.9f, %.9f, %.9f]', *_a[2])
        if self.verbose >= logger.DEBUG2:
            logger.debug2(self, 'Charge      Location')
            coords = self.mm_mol.atom_coords()
            charges = self.mm_mol.atom_charges()
            for i, z in enumerate(charges):
                logger.debug2(self, '%.9g    %s', z, coords[i])
        return self

    def get_mm_ewald_pot(self, mol, mm_mol):
        '''Ewald potential of the MM charges evaluated at the QM atom positions.

        The coordinates, charges and exponents are taken from the ``mm_mol``
        argument, while the evaluation itself is delegated to
        ``self.mm_mol.get_ewald_pot``.  Callers index the result with
        ``[0]``, ``[1]`` and ``[2]``.
        '''
        return self.mm_mol.get_ewald_pot(
            mol.atom_coords(),
            mm_mol.atom_coords(), mm_mol.atom_charges(), mm_mol.get_zetas())

    def get_qm_ewald_pot(self, mol, dm, qm_ewald_hess=None):
        '''Ewald potential generated by the QM multipoles of ``dm``.

        Args:
            mol : QM molecule
            dm : density matrix
            qm_ewald_hess : precomputed result of
                ``mm_mol.get_ewald_pot(mol.atom_coords())``.  When None it is
                computed and stored in ``self.qm_ewald_hess``.

        Returns:
            Tuple ``(ewpot0, ewpot1, ewpot2)``: the potentials that
            ``get_vdiff`` pairs with the QM charges, dipoles and quadrupoles.
        '''
        # hess = d^2 E / dQ_i dQ_j, d^2 E / dQ_i dD_ja, d^2 E / dDia dDjb, d^2 E/ dQ_i dO_jab
        if qm_ewald_hess is None:
            qm_ewald_hess = self.mm_mol.get_ewald_pot(mol.atom_coords())
            self.qm_ewald_hess = qm_ewald_hess
        dm = cp.asarray(dm)
        charges = self.get_qm_charges(dm)
        dips = self.get_qm_dipoles(dm)
        quads = self.get_qm_quadrupoles(dm)
        ewpot0  = contract('ij,j->i', qm_ewald_hess[0], charges)
        ewpot0 += contract('ijx,jx->i', qm_ewald_hess[1], dips)
        ewpot0 += contract('ijxy,jxy->i', qm_ewald_hess[3], quads)
        ewpot1  = contract('ijx,i->jx', qm_ewald_hess[1], charges)
        ewpot1 += contract('ijxy,jy->ix', qm_ewald_hess[2], dips)
        ewpot2  = contract('ijxy,j->ixy', qm_ewald_hess[3], charges)
        return ewpot0, ewpot1, ewpot2

    def get_hcore(self, mol=None):
        '''Core Hamiltonian including the MM charges close to the QM region.

        The parent ``get_hcore`` is evaluated first.  All periodic images of
        the MM charges whose squared distance to the QM centroid is at most
        ``rcut_hcore**2`` are then added through ``int1e_grids`` (Gaussian
        charge model).

        Raises:
            AssertionError : if a periodic image of a QM atom lies within
                ``rcut_hcore`` of the QM centroid, or if a QM atom lies
                outside ``rcut_hcore``.
            RuntimeError : for the point-charge model with at least one MM
                charge inside ``rcut_hcore``; that code path is not validated.
        '''
        cput0 = (logger.process_clock(), logger.perf_counter())
        mm_mol = self.mm_mol
        if mol is None:
            mol = self.mol
        rcut_hcore = mm_mol.rcut_hcore

        h1e = cp.asarray(super().get_hcore(mol))

        Ls = mm_mol.get_lattice_Ls()
        qm_center = np.mean(mol.atom_coords(), axis=0)

        # Temporarily replace the zero translation by inf so that the
        # reference cell itself drops out of the minimum-distance check; the
        # entries are restored to zero right after the assertion.
        mask = np.linalg.norm(Ls, axis=-1) < 1e-12
        Ls[mask] = [np.inf] * 3
        r_qm = (mol.atom_coords() - qm_center)[None,:,:] - Ls[:,None,:]
        r_qm = np.einsum('Lix,Lix->Li', r_qm, r_qm)
        assert rcut_hcore**2 < np.min(r_qm), \
             "QM image is within rcut_hcore of QM center. " + \
            f"rcut_hcore = {rcut_hcore} >= min(r_qm) = {np.sqrt(np.min(r_qm))}"
        Ls[Ls == np.inf] = 0.0

        r_qm = mol.atom_coords() - qm_center
        r_qm = np.einsum('ix,ix->i', r_qm, r_qm)
        assert rcut_hcore**2 > np.max(r_qm), \
             "Not all QM atoms are within rcut_hcore of QM center. " + \
            f"rcut_hcore = {rcut_hcore} <= max(r_qm) = {np.sqrt(np.max(r_qm))}"
        r_qm = None

        qm_center = cp.asarray(qm_center)
        all_coords = cp.asarray((mm_mol.atom_coords()[None,:,:]
                + Ls[:,None,:]).reshape(-1,3))
        all_charges = cp.hstack([mm_mol.atom_charges()] * len(Ls))
        dist2 = all_coords - qm_center
        dist2 = contract('ix,ix->i', dist2, dist2)

        # charges within rcut_hcore exactly go into hcore
        mask = dist2 <= rcut_hcore**2
        charges = all_charges[mask]
        coords = all_coords[mask]
        logger.note(self, '%d MM charges see directly QM density'%charges.shape[0])
        if mm_mol.charge_model == 'gaussian' and len(coords) != 0:
            expnts = cp.hstack([mm_mol.get_zetas()] * len(Ls))[mask]
            h1e += int1e_grids(mol, coords, charges = -charges, charge_exponents = expnts)
        elif mm_mol.charge_model == 'point' and len(coords) != 0:
            # Point charges must not be dropped silently: this branch is the
            # exponent-free counterpart of the Gaussian one and mirrors the
            # 'point' branch of get_hcore_mm.
            # TODO test this block
            raise RuntimeError("Not tested yet")
            # Unvalidated draft implementation, unreachable until tested.
            nao = mol.nao
            max_memory = self.max_memory - lib.current_memory()[0]
            blksize = int(min(max_memory*1e6/8/nao**2, 200))
            blksize = max(blksize, 1)
            for i0, i1 in lib.prange(0, charges.size, blksize):
                j3c = mol.intor('int1e_grids', hermi=1, grids=coords[i0:i1].get())
                h1e += contract('kpq,k->pq', cp.asarray(j3c), -charges[i0:i1])
        else: # no MM charges
            pass

        logger.timer(self, 'get_hcore', *cput0)
        return h1e

    def get_qm_charges(self, dm):
        '''Per-atom QM charges: nuclear charge minus the trace of the
        atom-diagonal block of ``dm @ S``.'''
        dm = cp.asarray(dm)
        aoslices = self.mol.aoslice_by_atom()
        chg = self.mol.atom_charges()
        dmS = cp.dot(dm, cp.asarray(self.get_ovlp()))
        qm_charges = list()
        for iatm in range(self.mol.natm):
            p0, p1 = aoslices[iatm, 2:]
            qm_charges.append(chg[iatm] - np.trace(dmS[p0:p1, p0:p1]))
        return cp.asarray(qm_charges)

    def get_s1r(self):
        '''Per-atom ``int1e_r`` integrals with the origin placed on each atom.

        For atom ``i`` the ket index is restricted to the shells of that atom
        (``shls_slice``).  The list is computed once and cached in
        ``self.s1r``.
        '''
        if self.s1r is None:
            cput0 = (logger.process_clock(), logger.perf_counter())
            self.s1r = list()
            mol = self.mol
            bas_atom = mol._bas[:,gto.ATOM_OF]
            for i in range(self.mol.natm):
                # first and last shell of atom i; assumes each atom owns a
                # contiguous, non-empty range of shells
                b0, b1 = np.where(bas_atom == i)[0][[0,-1]]
                shls_slice = (0, mol.nbas, b0, b1+1)
                with mol.with_common_orig(mol.atom_coord(i)):
                    self.s1r.append(
                        cp.asarray(mol.intor('int1e_r', shls_slice=shls_slice)))
            logger.timer(self, 'get_s1r', *cput0)
        return self.s1r

    def get_qm_dipoles(self, dm, s1r=None):
        '''Per-atom QM dipoles: minus the contraction of the AO rows of each
        atom in ``dm`` with that atom's ``s1r`` integrals (``get_s1r`` is
        used when ``s1r`` is None).'''
        dm = cp.asarray(dm)
        if s1r is None:
            s1r = self.get_s1r()
        aoslices = self.mol.aoslice_by_atom()
        qm_dipoles = list()
        for iatm in range(self.mol.natm):
            p0, p1 = aoslices[iatm, 2:]
            qm_dipoles.append(
                -contract('uv,xvu->x', dm[p0:p1], s1r[iatm]))
        return cp.asarray(qm_dipoles)

    def get_s1rr(self):
        r'''
        .. math:: \int phi_u phi_v [3(r-Rc)\otimes(r-Rc) - |r-Rc|^2] /2 dr

        One array per atom, with the origin ``Rc`` placed on that atom and
        the second AO index restricted to its shells.  Cached in
        ``self.s1rr``.
        '''
        if self.s1rr is None:
            cput0 = (logger.process_clock(), logger.perf_counter())
            self.s1rr = list()
            mol = self.mol
            nao = mol.nao
            bas_atom = mol._bas[:,gto.ATOM_OF]
            for i in range(self.mol.natm):
                b0, b1 = np.where(bas_atom == i)[0][[0,-1]]
                shls_slice = (0, mol.nbas, b0, b1+1)
                with mol.with_common_orig(mol.atom_coord(i)):
                    s1rr_ = mol.intor('int1e_rr', shls_slice=shls_slice)
                    s1rr_ = s1rr_.reshape((3,3,nao,-1))
                    # 3/2 * rr minus 1/2 * trace on the diagonal
                    s1rr_trace = lib.einsum('xxuv->uv', s1rr_)
                    s1rr_ = 3/2 * s1rr_
                    for k in range(3):
                        s1rr_[k,k] -= 0.5 * s1rr_trace
                    self.s1rr.append(cp.asarray(s1rr_))
            logger.timer(self, 'get_s1rr', *cput0)
        return self.s1rr

    def get_qm_quadrupoles(self, dm, s1rr=None):
        '''Per-atom QM quadrupoles: minus the contraction of the AO rows of
        each atom in ``dm`` with that atom's ``s1rr`` integrals
        (``get_s1rr`` is used when ``s1rr`` is None).'''
        dm = cp.asarray(dm)
        if s1rr is None:
            s1rr = self.get_s1rr()
        aoslices = self.mol.aoslice_by_atom()
        qm_quadrupoles = list()
        for iatm in range(self.mol.natm):
            p0, p1 = aoslices[iatm, 2:]
            qm_quadrupoles.append(
                -contract('uv,xyvu->xy', dm[p0:p1], s1rr[iatm]))
        return cp.asarray(qm_quadrupoles)

    def get_vdiff(self, mol, ewald_pot):
        '''
        vdiff_uv = d Q_I / d dm_uv ewald_pot[0]_I
                 + d D_Ix / d dm_uv ewald_pot[1]_Ix
                 + d O_Ixy / d dm_uv ewald_pot[2]_Ixy

        The result is symmetrized before it is returned.
        '''
        vdiff = cp.zeros((mol.nao, mol.nao))
        ovlp = self.get_ovlp()
        s1r  = self.get_s1r()
        s1rr = self.get_s1rr()
        aoslices = mol.aoslice_by_atom()
        for iatm in range(mol.natm):
            v0 = cp.asarray(ewald_pot[0][iatm])
            v1 = cp.asarray(ewald_pot[1][iatm])
            v2 = cp.asarray(ewald_pot[2][iatm])
            p0, p1 = aoslices[iatm, 2:]
            vdiff[:,p0:p1] -= v0 * ovlp[:,p0:p1]
            vdiff[:,p0:p1] -= contract('x,xuv->uv', v1, s1r[iatm])
            vdiff[:,p0:p1] -= contract('xy,xyuv->uv', v2, s1rr[iatm])
        vdiff = (vdiff + vdiff.T) / 2
        return vdiff

    def get_veff(self, mol=None, dm=None, dm_last=0, vhf_last=0, hermi=1,
                 mm_ewald_pot=None, qm_ewald_pot=None):
        '''Effective potential of the parent method plus the Ewald term.

        The MM Ewald potential is taken from the argument, from the cache or
        computed (and cached).  The QM Ewald potential is taken from the
        argument or recomputed from ``dm``.  The returned array is tagged
        with ``veff_rs``, the parent's effective potential without the Ewald
        term, which ``energy_elec`` uses later.
        '''
        if mol is None:
            mol = self.mol
        mm_mol = self.mm_mol

        if mm_ewald_pot is None:
            if self.mm_ewald_pot is not None:
                mm_ewald_pot = self.mm_ewald_pot
            else:
                cput0 = (logger.process_clock(), logger.perf_counter())
                mm_ewald_pot = self.get_mm_ewald_pot(mol, mm_mol)
                self.mm_ewald_pot = mm_ewald_pot
                logger.timer(self, 'get_mm_ewald_pot', *cput0)
        if qm_ewald_pot is None:
            if dm is None:
                # The QM multipoles need an explicit density matrix.
                dm = self.make_rdm1()
            if self.qm_ewald_hess is not None:
                qm_ewald_pot = self.get_qm_ewald_pot(
                        mol, dm, self.qm_ewald_hess)
            else:
                cput0 = (logger.process_clock(), logger.perf_counter())
                qm_ewald_pot = self.get_qm_ewald_pot(mol, dm)
                logger.timer(self, 'get_qm_ewald_pot', *cput0)

        ewald_pot = \
            mm_ewald_pot[0] + qm_ewald_pot[0], \
            mm_ewald_pot[1] + qm_ewald_pot[1], \
            mm_ewald_pot[2] + qm_ewald_pot[2]
        vdiff = self.get_vdiff(mol, ewald_pot)

        # Hand the parent its own previous potential, without the Ewald term.
        if vhf_last is not None and isinstance(vhf_last, cupy_helper.CPArrayWithTag):
            vhf_last = vhf_last.veff_rs

        veff = super().get_veff(mol, dm, dm_last, vhf_last, hermi)
        if isinstance(veff, cupy_helper.CPArrayWithTag):
            metadata = veff.__dict__
            veff = cupy_helper.tag_array(veff + vdiff, veff_rs=veff, **metadata)
        else:
            veff = cupy_helper.tag_array(veff + vdiff, veff_rs=veff)
        return veff

    def energy_elec(self, dm=None, h1e=None, vhf=None):
        '''Electronic energy of the parent method evaluated with the
        effective potential that excludes the Ewald term (``veff_rs``).'''
        if vhf is None:
            # use the original veff to compute energy
            vhf = super().get_veff(self.mol, dm)
            return super().energy_elec(dm=dm, h1e=h1e, vhf=vhf)
        else:
            return super().energy_elec(dm=dm, h1e=h1e, vhf=vhf.veff_rs)

    def energy_ewald(self, dm=None, mm_ewald_pot=None, qm_ewald_pot=None):
        '''Ewald energy of the QM multipoles.

        Each multipole order is contracted with the MM potential plus half of
        the QM potential of the same order.
        '''
        # QM-QM and QM-MM pbc correction
        if dm is None:
            dm = self.make_rdm1()
        else:
            dm = cp.asarray(dm)
        if mm_ewald_pot is None:
            if self.mm_ewald_pot is not None:
                mm_ewald_pot = self.mm_ewald_pot
            else:
                mm_ewald_pot = self.get_mm_ewald_pot(self.mol, self.mm_mol)
        if qm_ewald_pot is None:
            qm_ewald_pot = self.get_qm_ewald_pot(self.mol, dm, self.qm_ewald_hess)
        ewald_pot = mm_ewald_pot[0] + qm_ewald_pot[0] / 2
        e  = contract('i,i->', cp.asarray(ewald_pot), self.get_qm_charges(dm))
        ewald_pot = mm_ewald_pot[1] + qm_ewald_pot[1] / 2
        e += contract('ix,ix->', cp.asarray(ewald_pot), self.get_qm_dipoles(dm))
        ewald_pot = mm_ewald_pot[2] + qm_ewald_pot[2] / 2
        e += contract('ixy,ixy->', cp.asarray(ewald_pot), self.get_qm_quadrupoles(dm))
        # TODO add energy correction if sum(charges) !=0 ?
        return e

    def energy_nuc(self):
        '''Nuclear energy: ``mol.energy_nuc()`` plus the interaction of each
        QM nucleus with the MM charges within ``rcut_hcore`` of the QM
        centroid, screened by ``erf(sqrt(zeta) * r)``.  The value is cached
        in ``self.e_nuc``.'''
        if self.e_nuc is not None:
            return self.e_nuc
        else:
            cput0 = (logger.process_clock(), logger.perf_counter())
            # local import: shadows the module-level cupyx erf for this scope
            from scipy.special import erf
            # gas phase nuc energy
            nuc = self.mol.energy_nuc()    # qm_nuc - qm_nuc

            # select mm atoms within rcut_hcore
            mol = self.mol
            Ls = self.mm_mol.get_lattice_Ls()
            qm_center = np.mean(mol.atom_coords(), axis=0)
            all_coords = lib.direct_sum('ix+Lx->Lix',
                self.mm_mol.atom_coords(), Ls).reshape(-1,3)
            all_charges = np.hstack([self.mm_mol.atom_charges()] * len(Ls))
            all_expnts = np.hstack([np.sqrt(self.mm_mol.get_zetas())] * len(Ls))
            dist2 = all_coords - qm_center
            dist2 = lib.einsum('ix,ix->i', dist2, dist2)
            mask = dist2 <= self.mm_mol.rcut_hcore**2
            charges = all_charges[mask]
            coords = all_coords[mask]
            expnts = all_expnts[mask]

            # qm_nuc - mm_pc
            for j in range(self.mol.natm):
                q2, r2 = self.mol.atom_charge(j), self.mol.atom_coord(j)
                r = lib.norm(r2-coords, axis=1)
                nuc += q2*(charges * erf(expnts*r) /r).sum()
            logger.timer(self, 'energy_nuc', *cput0)
            self.e_nuc = nuc
            return nuc

    def energy_tot(self, dm=None, h1e=None, vhf=None, mm_ewald_pot=None, qm_ewald_pot=None):
        '''Total energy: electronic + nuclear + Ewald.  The nuclear and Ewald
        parts are also recorded in ``self.scf_summary``.'''
        nuc = self.energy_nuc()
        ewald = self.energy_ewald(dm=dm, mm_ewald_pot=mm_ewald_pot, qm_ewald_pot=qm_ewald_pot)
        e_tot = self.energy_elec(dm, h1e, vhf)[0] + nuc + ewald
        self.scf_summary['nuc'] = nuc.real
        self.scf_summary['ewald'] = ewald
        return e_tot

    def nuc_grad_method(self):
        '''Gradient object of the parent method wrapped by ``qmmm_grad_for_scf``.'''
        scf_grad = super().nuc_grad_method()
        return qmmm_grad_for_scf(scf_grad)

    Gradients = nuc_grad_method


def add_mm_charges_grad(scf_grad, atoms_or_coords, a, charges, radii=None,
        rcut_ewald=None, rcut_hcore=None, unit=None):
    '''Apply the MM charges to a QM gradients method.  It affects both the
    electronic and nuclear parts of the QM fragment.

    Args:
        scf_grad : a HF or DFT gradient object (grad.HF or grad.RKS etc)
            Once add_mm_charges_grad is applied, it affects all post-HF
            calculations, e.g. MP2, CCSD, MCSCF etc
        atoms_or_coords : 2D array, shape (N,3)
            MM particle coordinates
        a : 2D array, shape (3,3)
            Lattice vectors
        charges : 1D array
            MM particle charges
    Kwargs:
        radii : 1D array
            The Gaussian charge distribution radii of MM atoms.
        rcut_ewald : float
            The real-space Ewald cutoff.
        rcut_hcore : float
            The cutoff for exact MM potential when computing hcore.
        unit : str
            Bohr, AU, Ang (case insensitive). Default is the same as mol.unit

    Returns:
        The QM/MM gradients object produced by qmmm_grad_for_scf, whose
        underlying SCF object (``.base``) has its ``mm_mol`` replaced by the
        newly built MM environment.
    '''
    assert isinstance(scf_grad, grad.rhf.Gradients)
    qm_mol = scf_grad.mol
    length_unit = qm_mol.unit if unit is None else unit
    mm_cell = mm_mole.create_mm_mol(
        atoms_or_coords, a, charges,
        radii=radii, rcut_ewald=rcut_ewald, rcut_hcore=rcut_hcore,
        unit=length_unit)
    wrapped_grad = qmmm_grad_for_scf(scf_grad)
    # NOTE(review): only mm_mol is replaced here; quantities cached on the SCF
    # object for the previous MM environment are left untouched -- confirm
    # that this is intended.
    wrapped_grad.base.mm_mol = mm_cell
    return wrapped_grad

# Define method mm_charge_grad for backward compatibility
mm_charge_grad = add_mm_charges_grad

def get_hcore_mm(scf_grad, mol=None):
    '''
    Nuclear-gradient contribution to the core Hamiltonian from the MM charges
    that lie within rcut_hcore of the QM centroid.

    Args:
        scf_grad : gradients object whose ``base`` carries ``mm_mol``
        mol : QM molecule; defaults to ``scf_grad.mol``

    Returns:
        Array of shape (3, nao, nao).  It stays zero when no MM charge (over
        all lattice images) falls inside the cutoff.

    Raises:
        RuntimeError : for the point-charge model with at least one MM charge
            inside the cutoff (untested code path)
    '''
    mm_mol = scf_grad.base.mm_mol
    if mol is None:
        mol = scf_grad.mol
    nao = mol.nao
    g_qm = cp.zeros([3, nao, nao])
    rcut_hcore = mm_mol.rcut_hcore

    cell_coords = cp.asarray(mm_mol.atom_coords())
    cell_charges = cp.asarray(mm_mol.atom_charges())
    Ls = cp.asarray(mm_mol.get_lattice_Ls())
    n_images = len(Ls)

    # Replicate the MM charges over every lattice translation and measure
    # their squared distance to the centroid of the QM atoms.
    qm_center = cp.mean(cp.asarray(mol.atom_coords()), axis=0)
    all_coords = (cell_coords[None,:,:] + Ls[:,None,:]).reshape(-1,3)
    all_charges = cp.hstack([cell_charges] * n_images)
    displ = all_coords - qm_center
    dist2 = contract('ix,ix->i', displ, displ)

    # charges within rcut_hcore exactly go into hcore
    inside = dist2 <= rcut_hcore**2
    charges = all_charges[inside]
    coords = all_coords[inside]

    if mm_mol.charge_model == 'gaussian' and len(coords) != 0:
        expnts = cp.hstack([mm_mol.get_zetas()] * n_images)[inside]
        g_qm += int1e_grids_ip1(mol, coords, charges = charges, charge_exponents = expnts)
    elif mm_mol.charge_model == 'point' and len(coords) != 0:
        raise RuntimeError("Not tested yet")
        # Unvalidated draft implementation, unreachable until tested.
        max_memory = scf_grad.max_memory - lib.current_memory()[0]
        blksize = int(min(max_memory*1e6/8/nao**2/3, 200))
        blksize = max(blksize, 1)
        coords = coords.get()
        for i0, i1 in lib.prange(0, len(coords), blksize):
            j3c = cp.asarray(mol.intor('int1e_grids_ip', grids=coords[i0:i1]))
            g_qm += contract('ikpq,k->ipq', j3c, charges[i0:i1])
    # otherwise there is nothing to add
    return g_qm

def qmmm_grad_for_scf(scf_grad):
    '''Turn an SCF (HF or DFT) gradients object into the corresponding QM/MM
    gradients object for the total system.

    An object that is already a QMMMGrad is returned unchanged.  Otherwise the
    ``de_ewald_mm`` and ``de_nuc_mm`` attributes are initialized to None on
    ``scf_grad`` and a view of it is returned whose class combines QMMMGrad
    with the original class.

    Raises:
        NotImplementedError : if the underlying SCF object uses X2C
        AssertionError : if the underlying SCF object is not a QM/MM SCF
    '''
    base = scf_grad.base
    if getattr(base, 'with_x2c', None):
        raise NotImplementedError('X2C with QM/MM charges')

    # Avoid to initialize QMMMGrad twice
    if isinstance(scf_grad, QMMMGrad):
        return scf_grad

    assert isinstance(base, scf.hf.SCF) and isinstance(base, QMMM)

    for slot in ('de_ewald_mm', 'de_nuc_mm'):
        setattr(scf_grad, slot, None)
    make_view = scf_grad.view
    return make_view(lib.make_class((QMMMGrad, scf_grad.__class__)))

class QMMMGrad:
    __name_mixin__ = 'QMMM'
    _keys = {'de_ewald_mm', 'de_nuc_mm'}

    to_cpu = NotImplemented

    def __init__(self, scf_grad):
        '''Take over every instance attribute of ``scf_grad``.'''
        vars(self).update(vars(scf_grad))

    def dump_flags(self, verbose=None):
        '''Log the parent flags, the MM lattice vectors and, at DEBUG2
        verbosity, every MM charge with its location.  Returns ``self``.'''
        super().dump_flags(verbose)
        logger.info(self, '** Add background charges for %s **', self.__class__.__name__)
        mm_cell = self.base.mm_mol
        lattice = mm_cell.lattice_vectors()
        row_formats = (
            'lattice vectors  a1 [%.9f, %.9f, %.9f]',
            '                 a2 [%.9f, %.9f, %.9f]',
            '                 a3 [%.9f, %.9f, %.9f]',
        )
        for k, fmt in enumerate(row_formats):
            logger.info(self, fmt, *lattice[k])
        if self.verbose < logger.DEBUG2:
            return self
        logger.debug2(self, 'Charge      Location')
        positions = self.base.mm_mol.atom_coords()
        for idx, q in enumerate(self.base.mm_mol.atom_charges()):
            logger.debug2(self, '%.9g    %s', q, positions[idx])
        return self

    def grad_ewald(self, dm=None, with_mm=False, mm_ewald_pot=None, qm_ewald_pot=None):
        '''PBC correction energy grad w.r.t. qm and mm atom positions
        '''
        cput0 = (logger.process_clock(), logger.perf_counter())
        if dm is None: dm = self.base.make_rdm1()
        dm = cp.asarray(dm)
        mol = self.base.mol
        cell = self.base.mm_mol
        assert cell.dimension == 3
        qm_charges = self.base.get_qm_charges(dm)
        qm_dipoles = self.base.get_qm_dipoles(dm)
        qm_quads = self.base.get_qm_quadrupoles(dm)
        qm_coords = cp.asarray(self.base.mol.atom_coords())
        mm_charges = cp.asarray(self.base.mm_mol.atom_charges())
        mm_coords = cp.asarray(self.base.mm_mol.atom_coords())
        aoslices = mol.aoslice_by_atom()

        # nuc grad due to qm multipole change due to ovlp change
        qm_multipole_grad = cp.zeros_like(qm_coords)

        if mm_ewald_pot is None:
            if self.base.mm_ewald_pot is not None:
                mm_ewald_pot = self.base.mm_ewald_pot
            else:
                mm_ewald_pot = self.base.get_mm_ewald_pot(mol, cell)
        if qm_ewald_pot is None:
            qm_ewald_pot = self.base.get_qm_ewald_pot(mol, dm, self.base.qm_ewald_hess)
        ewald_pot = \
            mm_ewald_pot[0] + qm_ewald_pot[0], \
            mm_ewald_pot[1] + qm_ewald_pot[1], \
            mm_ewald_pot[2] + qm_ewald_pot[2]

        dEds = cp.zeros((mol.nao, mol.nao))
        dEdsr = cp.zeros((3, mol.nao, mol.nao))
        dEdsrr = cp.zeros((3, 3, mol.nao, mol.nao))
        s1 = cp.asarray(self.get_ovlp(mol)) # = -mol.intor('int1e_ipovlp')
        s1r = list()
        s1rr = list()
        bas_atom = mol._bas[:,gto.ATOM_OF]
        for iatm in range(mol.natm):
            v0 = cp.asarray(ewald_pot[0][iatm])
            v1 = cp.asarray(ewald_pot[1][iatm])
            v2 = cp.asarray(ewald_pot[2][iatm])
            p0, p1 = aoslices[iatm, 2:]

            dEds[p0:p1] -= v0 * dm[p0:p1]
            dEdsr[:,p0:p1] -= contract('x,uv->xuv', v1, dm[p0:p1])
            dEdsrr[:,:,p0:p1] -= contract('xy,uv->xyuv', v2, dm[p0:p1])

            b0, b1 = np.where(bas_atom == iatm)[0][[0,-1]]
            shlslc = (b0, b1+1, 0, mol.nbas)
            with mol.with_common_orig(qm_coords[iatm].get()):
                # s1r[a,x,u,v] = \int phi_u (r_a-Ri_a) (-\nabla_x phi_v) dr
                s1r.append(
                    cp.asarray(-mol.intor('int1e_irp', shls_slice=shlslc).
                               reshape(3, 3, -1, mol.nao)))
                # s1rr[a,b,x,u,v] =
                # \int phi_u [3/2*(r_a-Ri_a)(r_b-Ri_b)-1/2*(r-Ri)^2 delta_ab] (-\nable_x phi_v) dr
                s1rr_ = cp.asarray(-mol.intor('int1e_irrp', shls_slice=shlslc).
                            reshape(3, 3, 3, -1, mol.nao))
                s1rr_trace = cp.einsum('aaxuv->xuv', s1rr_)
                s1rr_ *= 3 / 2
                for k in range(3):
                    s1rr_[k,k] -= 0.5 * s1rr_trace
                s1rr.append(s1rr_)

        for jatm in range(mol.natm):
            p0, p1 = aoslices[jatm, 2:]

            # d E_qm_pc / d Ri with fixed ewald_pot
            qm_multipole_grad[jatm] += \
                contract('uv,xuv->x', dEds[p0:p1], s1[:,p0:p1]) \
              - contract('uv,xuv->x', dEds[:,p0:p1], s1[:,:,p0:p1])

            # d E_qm_dip / d Ri
            qm_multipole_grad[jatm] -= \
                 contract('auv,axuv->x', dEdsr[:,p0:p1], s1r[jatm])
            s1r_ = list()
            for iatm in range(mol.natm):
                s1r_.append(s1r[iatm][...,p0:p1])
            s1r_ = cp.concatenate(s1r_, axis=-2)
            qm_multipole_grad[jatm] += contract('auv,axuv->x', dEdsr[...,p0:p1], s1r_)

            # d E_qm_quad / d Ri
            qm_multipole_grad[jatm] -= \
                    contract('abuv,abxuv->x', dEdsrr[:,:,p0:p1], s1rr[jatm])
            s1rr_ = list()
            for iatm in range(mol.natm):
                s1rr_.append(s1rr[iatm][...,p0:p1])
            s1rr_ = cp.concatenate(s1rr_, axis=-2)
            qm_multipole_grad[jatm] += contract('abuv,abxuv->x', dEdsrr[...,p0:p1], s1rr_)

        cput1 = logger.timer(self, 'grad_ewald pulay', *cput0)
        s1 = s1r = s1rr = dEds = dEdsr = dEdsrr = None

        ew_eta, ew_cut = cell.get_ewald_params()

        # ---------------------------------------------- #
        # -------- Ewald real-space gradient ----------- #
        # ---------------------------------------------- #

        Lall = cp.asarray(cell.get_lattice_Ls())

        from pyscf import pbc
        rmax_qm = max(cp.linalg.norm(qm_coords - cp.mean(qm_coords, axis=0), axis=-1))
        qm_ewovrl_grad = cp.zeros_like(qm_coords)

        grad_Tij = lambda R, r: get_multipole_tensors_pp(R, [1,2,3], r)
        grad_kTij = lambda R, r, eta: get_multipole_tensors_pg(R, eta, [1,2,3], r)

        def grad_qm_multipole(Tija, Tijab, Tijabc,
                              qm_charges, qm_dipoles, qm_quads,
                              mm_charges):
            Tc   = contract('ijx,j->ix', Tija, mm_charges)
            res  = contract('i,ix->ix', qm_charges, Tc)
            Tc   = contract('ijxa,j->ixa', Tijab, mm_charges)
            res += contract('ia,ixa->ix', qm_dipoles, Tc)
            Tc   = contract('ijxab,j->ixab', Tijabc, mm_charges)
            res += contract('iab,ixab->ix', qm_quads, Tc) / 3
            return res

        def grad_mm_multipole(Tija, Tijab, Tijabc,
                              qm_charges, qm_dipoles, qm_quads,
                              mm_charges):
            Tc  = contract('i,ijx->jx', qm_charges, Tija)
            Tc += contract('ia,ijxa->jx', qm_dipoles, Tijab)
            Tc += contract('iab,ijxab->jx', qm_quads, Tijabc) / 3
            return contract('jx,j->jx', Tc, mm_charges)

        #------ qm - mm clasiical ewald energy gradient ------#
        all_mm_coords = (mm_coords[None,:,:] - Lall[:,None,:]).reshape(-1,3)
        all_mm_charges = cp.hstack([mm_charges] * len(Lall))
        dist2 = all_mm_coords - cp.mean(qm_coords, axis=0)[None]
        dist2 = contract('jx,jx->j', dist2, dist2)
        if with_mm:
            mm_ewovrl_grad = np.zeros_like(all_mm_coords)
        mem_avail = cupy_helper.get_avail_mem()
        blksize = int(mem_avail/64/3/len(all_mm_coords))
        if blksize == 0:
            raise RuntimeError(f"Not enough GPU memory, mem_avail = {mem_avail}, blkszie = {blksize}")
        for i0, i1 in lib.prange(0, mol.natm, blksize):
            R = qm_coords[i0:i1,None,:] - all_mm_coords[None,:,:]
            r = cp.linalg.norm(R, axis=-1)
            r[r<1e-16] = cp.inf

            # subtract the real-space Coulomb within rcut_hcore
            mask = dist2 <= cell.rcut_hcore**2
            Tija, Tijab, Tijabc = grad_Tij(R[:,mask], r[:,mask])
            qm_ewovrl_grad[i0:i1] -= grad_qm_multipole(Tija, Tijab, Tijabc,
                    qm_charges[i0:i1], qm_dipoles[i0:i1], qm_quads[i0:i1], all_mm_charges[mask])
            if with_mm:
                mm_ewovrl_grad[mask] += grad_mm_multipole(Tija, Tijab, Tijabc,
                        qm_charges[i0:i1], qm_dipoles[i0:i1], qm_quads[i0:i1], all_mm_charges[mask])

            # difference between MM gaussain charges and MM point charges
            mask = dist2 > cell.rcut_hcore**2
            zetas = cp.asarray(cell.get_zetas())
            min_expnt = cp.min(zetas)
            max_ewrcut = pbc.gto.cell._estimate_rcut(min_expnt, 0, 1., cell.precision)
            cut2 = (max_ewrcut + rmax_qm)**2
            mask = mask & (dist2 <= cut2)
            expnts = cp.hstack([cp.sqrt(zetas)] * len(Lall))[mask]
            if expnts.size != 0:
                Tija, Tijab, Tijabc = grad_kTij(R[:,mask], r[:,mask], expnts)
                qm_ewovrl_grad[i0:i1] -= grad_qm_multipole(Tija, Tijab, Tijabc,
                        qm_charges[i0:i1], qm_dipoles[i0:i1], qm_quads[i0:i1], all_mm_charges[mask])
                if with_mm:
                    mm_ewovrl_grad[mask] += grad_mm_multipole(Tija, Tijab, Tijabc,
                            qm_charges[i0:i1], qm_dipoles[i0:i1], qm_quads[i0:i1], all_mm_charges[mask])

            # ewald real-space sum
            cut2 = (ew_cut + rmax_qm)**2
            mask = dist2 <= cut2
            Tija, Tijab, Tijabc = grad_kTij(R[:,mask], r[:,mask], ew_eta)
            qm_ewovrl_grad[i0:i1] += grad_qm_multipole(Tija, Tijab, Tijabc,
                    qm_charges[i0:i1], qm_dipoles[i0:i1], qm_quads[i0:i1], all_mm_charges[mask])
            if with_mm:
                mm_ewovrl_grad[mask] -= grad_mm_multipole(Tija, Tijab, Tijabc,
                        qm_charges[i0:i1], qm_dipoles[i0:i1], qm_quads[i0:i1], all_mm_charges[mask])

        if with_mm:
            mm_ewovrl_grad = mm_ewovrl_grad.reshape(len(Lall), -1, 3)
            mm_ewovrl_grad = cp.sum(mm_ewovrl_grad, axis=0)
        all_mm_coords = all_mm_charges = None

        #------ qm - qm clasiical ewald energy gradient ------#
        R = qm_coords[:,None,:] - qm_coords[None,:,:]
        r = np.sqrt(contract('ijx,ijx->ij', R, R))
        r[r<1e-16] = 1e100

        # subtract the real-space Coulomb within rcut_hcore
        Tija, Tijab, Tijabc = grad_Tij(R, r)
        #qm_ewovrl_grad -= cp.einsum('i,ijx,j->ix', qm_charges, Tija, qm_charges)
        #qm_ewovrl_grad += cp.einsum('i,ijxa,ja->ix', qm_charges, Tijab, qm_dipoles)
        #qm_ewovrl_grad -= cp.einsum('i,ijxa,ja->jx', qm_charges, Tijab, qm_dipoles) #
        #qm_ewovrl_grad += cp.einsum('ia,ijxab,jb->ix', qm_dipoles, Tijabc, qm_dipoles)
        #qm_ewovrl_grad -= cp.einsum('i,ijxab,jab->ix', qm_charges, Tijabc, qm_quads) / 3
        #qm_ewovrl_grad += cp.einsum('i,ijxab,jab->jx', qm_charges, Tijabc, qm_quads) / 3 #
        temp = contract('ijx,j->ix', Tija, qm_charges)
        qm_ewovrl_grad -= contract('i,ix->ix', qm_charges, temp)
        temp = contract('ijxa,ja->ix', Tijab, qm_dipoles)
        qm_ewovrl_grad += contract('i,ix->ix', qm_charges, temp)
        temp = contract('i,ijxa->jxa', qm_charges, Tijab)
        qm_ewovrl_grad -= contract('jxa,ja->jx', temp, qm_dipoles) #
        temp = contract('ijxab,jb->ixa', Tijabc, qm_dipoles)
        qm_ewovrl_grad += contract('ia,ixa->ix', qm_dipoles, temp)
        temp = contract('ijxab,jab->ix', Tijabc, qm_quads)
        qm_ewovrl_grad -= contract('i,ix->ix', qm_charges, temp) / 3
        temp = contract('ijxab,jab->ijx', Tijabc, qm_quads)
        qm_ewovrl_grad += contract('i,ijx->jx', qm_charges, temp) / 3 #
        temp = None

        # ewald real-space sum
        # NOTE here I assume ewald real-space sum is over all qm images
        # consistent with mm_mole.get_ewald_pot
        R = (R[:,:,None,:] - Lall[None,None]).reshape(len(qm_coords), len(Lall)*len(qm_coords), 3)
        r = np.sqrt(contract('ijx,ijx->ij', R, R))
        r[r<1e-16] = 1e100
        Tija, Tijab, Tijabc = grad_kTij(R, r, ew_eta)
        Tija = Tija.reshape(len(qm_coords), len(qm_coords), len(Lall), 3)
        Tijab = Tijab.reshape(len(qm_coords), len(qm_coords), len(Lall), 3, 3)
        Tijabc = Tijabc.reshape(len(qm_coords), len(qm_coords), len(Lall), 3, 3, 3)
        #qm_ewovrl_grad += cp.einsum('i,ijLx,j->ix', qm_charges, Tija, qm_charges)
        #qm_ewovrl_grad -= cp.einsum('i,ijLxa,ja->ix', qm_charges, Tijab, qm_dipoles)
        #qm_ewovrl_grad += cp.einsum('i,ijLxa,ja->jx', qm_charges, Tijab, qm_dipoles) #
        #qm_ewovrl_grad -= cp.einsum('ia,ijLxab,jb->ix', qm_dipoles, Tijabc, qm_dipoles)
        #qm_ewovrl_grad += cp.einsum('i,ijLxab,jab->ix', qm_charges, Tijabc, qm_quads) / 3
        #qm_ewovrl_grad -= cp.einsum('i,ijLxab,jab->jx', qm_charges, Tijabc, qm_quads) / 3 #
        temp = contract('ijLx,j->ix', Tija, qm_charges)
        qm_ewovrl_grad += contract('i,ix->ix', qm_charges, temp)
        temp = contract('ijLxa,ja->ix', Tijab, qm_dipoles)
        qm_ewovrl_grad -= contract('i,ix->ix', qm_charges, temp)
        temp = contract('i,ijLxa->jxa', qm_charges, Tijab)
        qm_ewovrl_grad += contract('jxa,ja->jx', temp, qm_dipoles) #
        temp = contract('ijLxab,jb->ixa', Tijabc, qm_dipoles)
        qm_ewovrl_grad -= contract('ia,ixa->ix', qm_dipoles, temp)
        temp = contract('ijLxab,jab->ix', Tijabc, qm_quads)
        qm_ewovrl_grad += contract('i,ix->ix', qm_charges, temp) / 3
        temp = contract('i,ijLxab->jxab', qm_charges, Tijabc)
        qm_ewovrl_grad -= contract('jxab,jab->jx', temp, qm_quads) / 3 #

        cput2 = logger.timer(self, 'grad_ewald real-space', *cput1)

        # ---------------------------------------------- #
        # ---------- Ewald k-space gradient ------------ #
        # ---------------------------------------------- #

        mesh = cell.mesh
        Gv, Gvbase, weights = cell.get_Gv_weights(mesh)
        Gv = cp.asarray(Gv)
        absG2 = contract('gi,gi->g', Gv, Gv)
        absG2[absG2==0] = 1e200
        coulG = 4*np.pi / absG2
        coulG *= weights
        Gpref = cp.exp(-absG2/(4*ew_eta**2)) * coulG

        GvRmm = contract('gx,ix->ig', Gv, mm_coords)
        cosGvRmm = cp.cos(GvRmm)
        sinGvRmm = cp.sin(GvRmm)
        zcosGvRmm = contract("i,ig->g", mm_charges, cosGvRmm)
        zsinGvRmm = contract("i,ig->g", mm_charges, sinGvRmm)

        GvRqm = contract('gx,ix->ig', Gv, qm_coords)
        cosGvRqm = cp.cos(GvRqm)
        sinGvRqm = cp.sin(GvRqm)
        zcosGvRqm = contract("i,ig->g", qm_charges, cosGvRqm)
        zsinGvRqm = contract("i,ig->g", qm_charges, sinGvRqm)
        #DGcosGvRqm = cp.einsum("ia,ga,ig->g", qm_dipoles, Gv, cosGvRqm)
        #DGsinGvRqm = cp.einsum("ia,ga,ig->g", qm_dipoles, Gv, sinGvRqm)
        #TGGcosGvRqm = cp.einsum("iab,ga,gb,ig->g", qm_quads, Gv, Gv, cosGvRqm)
        #TGGsinGvRqm = cp.einsum("iab,ga,gb,ig->g", qm_quads, Gv, Gv, sinGvRqm)
        temp = contract('ia,ig->ag', qm_dipoles, cosGvRqm)
        DGcosGvRqm = contract('ga,ag->g', Gv, temp)
        temp = contract('ia,ig->ag', qm_dipoles, sinGvRqm)
        DGsinGvRqm = contract('ga,ag->g', Gv, temp)
        temp = contract('iab,ig->abg', qm_quads, cosGvRqm)
        temp = contract('abg,ga->bg', temp, Gv)
        TGGcosGvRqm = contract('gb,bg->g', Gv, temp)
        temp = contract('iab,ig->abg', qm_quads, sinGvRqm)
        temp = contract('abg,ga->bg', temp, Gv)
        TGGsinGvRqm = contract('gb,bg->g', Gv, temp)

        qm_ewg_grad = cp.zeros_like(qm_coords)
        if with_mm:
            mm_ewg_grad = cp.zeros_like(mm_coords)

        # qm pc - mm pc
        #p = ['einsum_path', (3, 4), (1, 3), (1, 2), (0, 1)]
        #qm_ewg_grad -= cp.einsum('i,gx,ig,g,g->ix', qm_charges, Gv, sinGvRqm, zcosGvRmm, Gpref, optimize=p)
        #qm_ewg_grad += cp.einsum('i,gx,ig,g,g->ix', qm_charges, Gv, cosGvRqm, zsinGvRmm, Gpref, optimize=p)
        temp = contract('g,g->g', zcosGvRmm, Gpref)
        temp = contract('gx,g->gx', Gv, temp)
        temp = contract('ig,gx->ix', sinGvRqm, temp)
        qm_ewg_grad -= contract('i,ix->ix', qm_charges, temp)
        temp = contract('g,g->g', zsinGvRmm, Gpref)
        temp = contract('gx,g->gx', Gv, temp)
        temp = contract('ig,gx->ix', cosGvRqm, temp)
        qm_ewg_grad += contract('i,ix->ix', qm_charges, temp)
        if with_mm:
            #p = ['einsum_path', (0, 2), (1, 2), (0, 2), (0, 1)]
            #mm_ewg_grad -= cp.einsum('i,gx,ig,g,g->ix', mm_charges, Gv, sinGvRmm, zcosGvRqm, Gpref, optimize=p)
            #mm_ewg_grad += cp.einsum('i,gx,ig,g,g->ix', mm_charges, Gv, cosGvRmm, zsinGvRqm, Gpref, optimize=p)
            temp = contract('i,ig->gi', mm_charges, sinGvRmm)
            temp2 = contract('g,g->g', zcosGvRqm, Gpref)
            temp2 = contract('gx,g->gx', Gv, temp2)
            mm_ewg_grad -= contract('gi,gx->ix', temp, temp2)
            temp = contract('i,ig->gi', mm_charges, cosGvRmm)
            temp2 = contract('g,g->g', zsinGvRqm, Gpref)
            temp2 = contract('gx,g->gx', Gv, temp2)
            mm_ewg_grad += contract('gi,gx->ix', temp, temp2)
        # qm dip - mm pc
        #p = ['einsum_path', (4, 5), (1, 4), (0, 1), (0, 2), (0, 1)]
        #qm_ewg_grad -= cp.einsum('ia,gx,ga,ig,g,g->ix', qm_dipoles, Gv, Gv, sinGvRqm, zsinGvRmm, Gpref, optimize=p)
        #qm_ewg_grad -= cp.einsum('ia,gx,ga,ig,g,g->ix', qm_dipoles, Gv, Gv, cosGvRqm, zcosGvRmm, Gpref, optimize=p)
        temp = contract('g,g->g', zsinGvRmm, Gpref)
        temp = contract('gx,g->gx', Gv, temp)
        temp2 = contract('ia,ga->gi', qm_dipoles, Gv)
        temp2 = contract('ig,gi->ig', sinGvRqm, temp2)
        qm_ewg_grad -= contract('gx,ig->ix', temp, temp2)
        temp = contract('g,g->g', zcosGvRmm, Gpref)
        temp = contract('gx,g->gx', Gv, temp)
        temp2 = contract('ia,ga->gi', qm_dipoles, Gv)
        temp2 = contract('ig,gi->ig', cosGvRqm, temp2)
        qm_ewg_grad -= contract('gx,ig->ix', temp, temp2)
        if with_mm:
            #p = ['einsum_path', (1, 3), (0, 2), (0, 2), (0, 1)]
            #mm_ewg_grad += cp.einsum('g,j,gx,jg,g->jx', DGcosGvRqm, mm_charges, Gv, cosGvRmm, Gpref, optimize=p)
            #mm_ewg_grad += cp.einsum('g,j,gx,jg,g->jx', DGsinGvRqm, mm_charges, Gv, sinGvRmm, Gpref, optimize=p)
            temp = contract('j,jg->gj', mm_charges, cosGvRmm)
            temp2 = contract('g,g->g', DGcosGvRqm, Gpref)
            temp2 = contract('gx,g->gx', Gv, temp2)
            mm_ewg_grad += contract('gj,gx->jx', temp, temp2)
            temp = contract('j,jg->gj', mm_charges, sinGvRmm)
            temp2 = contract('g,g->g', DGsinGvRqm, Gpref)
            temp2 = contract('gx,g->gx', Gv, temp2)
            mm_ewg_grad += contract('gj,gx->jx', temp, temp2)
        # qm quad - mm pc
        #p = ['einsum_path', (5, 6), (0, 5), (0, 2), (2, 3), (1, 2), (0, 1)]
        #qm_ewg_grad += cp.einsum('ga,gb,iab,gx,ig,g,g->ix', Gv, Gv, qm_quads, Gv, sinGvRqm, zcosGvRmm, Gpref, optimize=p) / 3
        #qm_ewg_grad -= cp.einsum('ga,gb,iab,gx,ig,g,g->ix', Gv, Gv, qm_quads, Gv, cosGvRqm, zsinGvRmm, Gpref, optimize=p) / 3
        temp = contract('g,g->g', zcosGvRmm, Gpref)
        temp = contract('ga,g->ag', Gv, temp)
        temp2 = contract('gb,gx->gbx', Gv, Gv)
        temp = contract('ag,gbx->abgx', temp, temp2)
        temp = contract('ig,abgx->iabx', sinGvRqm, temp)
        qm_ewg_grad += contract('iab,iabx->ix', qm_quads, temp) / 3
        temp = contract('g,g->g', zsinGvRmm, Gpref)
        temp = contract('ga,g->ag', Gv, temp)
        temp2 = contract('gb,gx->gbx', Gv, Gv)
        temp = contract('ag,gbx->abgx', temp, temp2)
        temp = contract('ig,abgx->iabx', cosGvRqm, temp)
        qm_ewg_grad -= contract('iab,iabx->ix', qm_quads, temp) / 3
        if with_mm:
            #p = ['einsum_path', (1, 3), (0, 2), (0, 2), (0, 1)]
            #mm_ewg_grad += cp.einsum('g,j,gx,jg,g->jx', TGGcosGvRqm, mm_charges, Gv, sinGvRmm, Gpref, optimize=p) / 3
            #mm_ewg_grad -= cp.einsum('g,j,gx,jg,g->jx', TGGsinGvRqm, mm_charges, Gv, cosGvRmm, Gpref, optimize=p) / 3
            temp = contract('j,jg->gj', mm_charges, sinGvRmm)
            temp2 = contract('g,g->g', TGGcosGvRqm, Gpref)
            temp2 = contract('gx,g->gx', Gv, temp2)
            mm_ewg_grad += contract('gj,gx->jx', temp, temp2) / 3
            temp = contract('j,jg->gj', mm_charges, cosGvRmm)
            temp2 = contract('g,g->g', TGGsinGvRqm, Gpref)
            temp2 = contract('gx,g->gx', Gv, temp2)
            mm_ewg_grad -= contract('gj,gx->jx', temp, temp2) / 3

        # qm pc - qm pc
        #p = ['einsum_path', (3, 4), (1, 3), (1, 2), (0, 1)]
        #qm_ewg_grad -= cp.einsum('i,gx,ig,g,g->ix', qm_charges, Gv, sinGvRqm, zcosGvRqm, Gpref, optimize=p)
        #qm_ewg_grad += cp.einsum('i,gx,ig,g,g->ix', qm_charges, Gv, cosGvRqm, zsinGvRqm, Gpref, optimize=p)
        temp = contract('g,g->g', zcosGvRqm, Gpref)
        temp = contract('gx,g->gx', Gv, temp)
        temp = contract('ig,gx->ix', sinGvRqm, temp)
        qm_ewg_grad -= contract('i,ix->ix', qm_charges, temp)
        temp = contract('g,g->g', zsinGvRqm, Gpref)
        temp = contract('gx,g->gx', Gv, temp)
        temp = contract('ig,gx->ix', cosGvRqm, temp)
        qm_ewg_grad += contract('i,ix->ix', qm_charges, temp)
        # qm pc - qm dip
        #qm_ewg_grad += cp.einsum('i,gx,ig,g,g->ix', qm_charges, Gv, cosGvRqm, DGcosGvRqm, Gpref, optimize=p)
        #qm_ewg_grad += cp.einsum('i,gx,ig,g,g->ix', qm_charges, Gv, sinGvRqm, DGsinGvRqm, Gpref, optimize=p)
        temp = contract('g,g->g', DGcosGvRqm, Gpref)
        temp = contract('gx,g->gx', Gv, temp)
        temp = contract('ig,gx->ix', cosGvRqm, temp)
        qm_ewg_grad += contract('i,ix->ix', qm_charges, temp)
        temp = contract('g,g->g', DGsinGvRqm, Gpref)
        temp = contract('gx,g->gx', Gv, temp)
        temp = contract('ig,gx->ix', sinGvRqm, temp)
        qm_ewg_grad += contract('i,ix->ix', qm_charges, temp)
        #p = ['einsum_path', (3, 5), (1, 4), (1, 3), (1, 2), (0, 1)]
        #qm_ewg_grad -= cp.einsum('ja,ga,gx,g,jg,g->jx', qm_dipoles, Gv, Gv, zsinGvRqm, sinGvRqm, Gpref, optimize=p)
        #qm_ewg_grad -= cp.einsum('ja,ga,gx,g,jg,g->jx', qm_dipoles, Gv, Gv, zcosGvRqm, cosGvRqm, Gpref, optimize=p)
        temp = contract('g,g->g', zsinGvRqm, Gpref)
        temp = contract('ga,g->ag', Gv, temp)
        temp = contract('gx,ag->axg', Gv, temp)
        temp = contract('jg,axg->ajx', sinGvRqm, temp)
        qm_ewg_grad -= contract('ja,ajx->jx', qm_dipoles, temp)
        temp = contract('g,g->g', zcosGvRqm, Gpref)
        temp = contract('ga,g->ag', Gv, temp)
        temp = contract('gx,ag->axg', Gv, temp)
        temp = contract('jg,axg->ajx', cosGvRqm, temp)
        qm_ewg_grad -= contract('ja,ajx->jx', qm_dipoles, temp)
        # qm dip - qm dip
        #p = ['einsum_path', (4, 5), (1, 4), (1, 3), (1, 2), (0, 1)]
        #qm_ewg_grad -= cp.einsum('ia,ga,gx,ig,g,g->ix', qm_dipoles, Gv, Gv, sinGvRqm, DGcosGvRqm, Gpref, optimize=p)
        #qm_ewg_grad += cp.einsum('ia,ga,gx,ig,g,g->ix', qm_dipoles, Gv, Gv, cosGvRqm, DGsinGvRqm, Gpref, optimize=p)
        temp = contract('g,g->g', DGcosGvRqm, Gpref)
        temp = contract('ga,g->ag', Gv, temp)
        temp = contract('gx,ag->axg', Gv, temp)
        temp = contract('ig,axg->aix', sinGvRqm, temp)
        qm_ewg_grad -= contract('ia,aix->ix', qm_dipoles, temp)
        temp = contract('g,g->g', DGsinGvRqm, Gpref)
        temp = contract('ga,g->ag', Gv, temp)
        temp = contract('gx,ag->axg', Gv, temp)
        temp = contract('ig,axg->aix', cosGvRqm, temp)
        qm_ewg_grad += contract('ia,aix->ix', qm_dipoles, temp)
        # qm pc - qm quad
        #p = ['einsum_path', (3, 4), (1, 3), (1, 2), (0, 1)]
        #qm_ewg_grad += cp.einsum('i,gx,ig,g,g->ix', qm_charges, Gv, sinGvRqm, TGGcosGvRqm, Gpref, optimize=p) / 3
        #qm_ewg_grad -= cp.einsum('i,gx,ig,g,g->ix', qm_charges, Gv, cosGvRqm, TGGsinGvRqm, Gpref, optimize=p) / 3
        temp = contract('g,g->g', TGGcosGvRqm, Gpref)
        temp = contract('gx,g->gx', Gv, temp)
        temp = contract('ig,gx->ix', sinGvRqm, temp)
        qm_ewg_grad += contract('i,ix->ix', qm_charges, temp) / 3
        temp = contract('g,g->g', TGGsinGvRqm, Gpref)
        temp = contract('gx,g->gx', Gv, temp)
        temp = contract('ig,gx->ix', cosGvRqm, temp)
        qm_ewg_grad -= contract('i,ix->ix', qm_charges, temp) / 3
        #p = ['einsum_path', (4, 6), (1, 5), (1, 2), (2, 3), (1, 2), (0, 1)]
        #qm_ewg_grad += cp.einsum('jab,ga,gb,gx,g,jg,g->jx', qm_quads, Gv, Gv, Gv, zcosGvRqm, sinGvRqm, Gpref, optimize=p) / 3
        #qm_ewg_grad -= cp.einsum('jab,ga,gb,gx,g,jg,g->jx', qm_quads, Gv, Gv, Gv, zsinGvRqm, cosGvRqm, Gpref, optimize=p) / 3
        temp = contract('g,g->g', zcosGvRqm, Gpref)
        temp = contract('ga,g->ag', Gv, temp)
        temp2 = contract('gb,gx->bgx', Gv, Gv)
        temp = contract('ag,bgx->abgx', temp, temp2)
        temp = contract('jg,abgx->abjx', sinGvRqm, temp)
        qm_ewg_grad += contract('jab,abjx->jx', qm_quads, temp) / 3
        temp = contract('g,g->g', zsinGvRqm, Gpref)
        temp = contract('ga,g->ag', Gv, temp)
        temp2 = contract('gb,gx->bgx', Gv, Gv)
        temp = contract('ag,bgx->abgx', temp, temp2)
        temp = contract('jg,abgx->abjx', cosGvRqm, temp)
        qm_ewg_grad -= contract('jab,abjx->jx', qm_quads, temp) / 3

        logger.timer(self, 'grad_ewald k-space', *cput2)
        logger.timer(self, 'grad_ewald', *cput0)
        if not with_mm:
            return (qm_multipole_grad + qm_ewovrl_grad + qm_ewg_grad).get()
        else:
            return (qm_multipole_grad + qm_ewovrl_grad + qm_ewg_grad).get(), \
                   (mm_ewovrl_grad + mm_ewg_grad).get()

    def get_hcore(self, mol=None, exclude_ecp=False):
        '''Parent-class hcore term plus the term returned by get_hcore_mm.

        The parent ``get_hcore`` is submitted through
        ``lib.call_in_background`` and writes into a preallocated
        ``[3, nao, nao]`` buffer, while ``get_hcore_mm(self)`` is called inside
        the same ``with`` block.  The two arrays are added after that block
        has been left.

        Args:
            mol : QM molecule; defaults to ``self.mol``
            exclude_ecp : forwarded unchanged to the parent ``get_hcore``
        '''
        mol = self.mol if mol is None else mol
        t_start = (logger.process_clock(), logger.perf_counter())

        # Destination buffer that the background task fills in place.
        h1_parent = cp.empty([3, mol.nao, mol.nao])

        def _fill_parent_hcore(self, out):
            # ``self`` has to remain the first positional argument: the
            # zero-argument super() below relies on it.
            out[:] = cp.asarray(super().get_hcore(mol, exclude_ecp=exclude_ecp))

        with lib.call_in_background(_fill_parent_hcore) as submit:
            submit(self, h1_parent)
            h1_mm = get_hcore_mm(self)

        logger.timer(self, 'get_hcore', *t_start)
        return h1_parent + h1_mm

    def grad_hcore_mm(self, dm, mol=None):
        '''Nuclear gradients of the electronic energy, w.r.t. the MM charges

        Every MM charge is replicated over the lattice vectors from
        ``mm_mol.get_lattice_Ls()``.  Only the images whose squared distance
        to the mean of the QM atom coordinates is within
        ``mm_mol.rcut_hcore**2`` are handed to ``int1e_grids_ip2``; the rows
        of all other images stay zero.  The per-image rows are finally summed
        over the lattice images.

        Args:
            dm : density matrix; converted with ``cp.asarray``
            mol : QM molecule; defaults to ``self.mol``

        Returns:
            Host array (result of ``.get()``) with one row of three
            components per MM atom.
        '''
        cput0 = (logger.process_clock(), logger.perf_counter())
        dm = cp.asarray(dm)
        mm_mol = self.base.mm_mol
        if mol is None:
            mol = self.mol
        rcut_hcore = mm_mol.rcut_hcore

        coords = cp.asarray(mm_mol.atom_coords())
        charges = cp.asarray(mm_mol.atom_charges())
        expnts = cp.asarray(mm_mol.get_zetas())

        Ls = cp.asarray(mm_mol.get_lattice_Ls())
        nimgs = len(Ls)

        qm_center = cp.mean(cp.asarray(mol.atom_coords()), axis=0)
        # image-major flattening: row index = image * natm_mm + atom
        all_coords = (coords[None,:,:] + Ls[:,None,:]).reshape(-1,3)
        all_charges = cp.hstack([charges] * nimgs)
        all_expnts = cp.hstack([expnts] * nimgs)
        dist2 = all_coords - qm_center
        dist2 = contract('ix,ix->i', dist2, dist2)

        # charges within rcut_hcore exactly go into hcore
        mask = dist2 <= rcut_hcore**2
        charges = all_charges[mask]
        coords = all_coords[mask]
        # The masked exponents are reused below as they are; they are not
        # rebuilt from mm_mol.get_zetas() a second time.
        expnts = all_expnts[mask]

        g = cp.zeros_like(all_coords)
        if len(coords) != 0:
            # the transposed integral result is scattered into the selected rows
            g[mask] = int1e_grids_ip2(mol, coords, dm = dm, charges = charges, charge_exponents = expnts).T
        # fold the lattice images back onto the MM atoms of the reference cell
        g = cp.sum(g.reshape(nimgs, -1, 3), axis=0)

        logger.timer(self, 'grad_hcore_mm', *cput0)
        return g.get()

    # contract_hcore_mm is a second name bound to the same function as
    # grad_hcore_mm; both take (dm, mol=None).
    contract_hcore_mm = grad_hcore_mm

    def grad_nuc(self, mol=None, atmlst=None, with_mm=True):
        '''Nuclear gradient on the QM atoms, including the QM nuclei - MM
        charge interaction within ``mm_mol.rcut_hcore``.

        For every QM nucleus (charge q1 at r1) and every selected MM charge
        image (charge q at c, exponent a = sqrt(zeta)) the term

            q1 * q * (r1 - c) * (erf(a*r)/r**3 - 2*a/sqrt(pi) * exp(-a**2*r**2)/r**2)

        with r = |r1 - c| is added to the row of the MM charge and subtracted
        from the row of the QM atom.  MM images are selected by their
        distance to the mean of the QM atom coordinates.

        As a side effect, the MM part (summed over lattice images, moved to
        the host) is stored in ``self.de_nuc_mm``.

        Args:
            mol : QM molecule; defaults to ``self.mol``
            atmlst : must be None
            with_mm : accepted but not used in this method

        Returns:
            Host array (result of ``.get()``) with the gradient on the QM atoms.
        '''
        cput0 = (logger.process_clock(), logger.perf_counter())
        assert atmlst is None # atmlst needs to be full for computing g_mm
        if mol is None:
            mol = self.mol
        mm_mol = self.base.mm_mol
        coords = cp.asarray(mm_mol.atom_coords())
        charges = cp.asarray(mm_mol.atom_charges())
        Ls = cp.asarray(mm_mol.get_lattice_Ls())
        nimgs = len(Ls)
        qm_center = cp.mean(cp.asarray(mol.atom_coords()), axis=0)
        # image-major flattening: row index = image * natm_mm + atom
        all_coords = (coords[None,:,:] + Ls[:,None,:]).reshape(-1,3)
        all_charges = cp.hstack([charges] * nimgs)
        all_expnts = cp.hstack([cp.sqrt(cp.asarray(mm_mol.get_zetas()))] * nimgs)
        dist2 = all_coords - qm_center
        dist2 = contract('ix,ix->i', dist2, dist2)
        mask = dist2 <= mm_mol.rcut_hcore**2
        charges = all_charges[mask]
        coords = all_coords[mask]
        expnts = all_expnts[mask]

        # Forward mol explicitly so the parent works on the same molecule as
        # the MM terms below; atmlst alone would land in the parent's first
        # positional slot.
        g_qm = cp.asarray(super().grad_nuc(mol, atmlst))
        g_mm = cp.zeros_like(all_coords)
        for i in range(mol.natm):
            q1 = mol.atom_charge(i)
            # displacement from every selected MM image to QM nucleus i
            dr = cp.asarray(mol.atom_coord(i)) - coords
            r = cp.linalg.norm(dr, axis=1)
            g_mm_  =  q1 * contract('ix,i->ix', dr, charges * erf(expnts*r)/r**3)
            g_mm_ -=  q1 * contract('ix,i->ix', dr, charges * expnts * 2 / np.sqrt(np.pi)
                              * cp.exp(-expnts**2 * r**2)/r**2)
            g_mm[mask] += g_mm_
            # equal and opposite contribution on the QM nucleus
            g_qm[i]    -= cp.sum(g_mm_, axis=0)
        # fold the lattice images back onto the MM atoms of the reference cell
        g_mm = cp.sum(g_mm.reshape(nimgs, -1, 3), axis=0)
        self.de_nuc_mm = g_mm.get()
        logger.timer(self, 'grad_nuc', *cput0)
        return g_qm.get()

    def grad_nuc_mm(self, mol=None):
        '''Gradient on the MM charges from their interaction with the QM nuclei.

        If ``self.de_nuc_mm`` is not None it is returned as is and nothing is
        recomputed.  Otherwise the gradient is evaluated on the host with
        numpy/scipy: MM charges are replicated over
        ``mm_mol.get_lattice_Ls()``, images within ``mm_mol.rcut_hcore`` of
        the mean QM atom position are selected, and for each QM nucleus
        (charge q1 at r1) and selected image (charge q at c, exponent
        a = sqrt(zeta)) the term

            q1 * q * (r1 - c) * (erf(a*r)/r**3 - 2*a/sqrt(pi) * exp(-a**2*r**2)/r**2)

        with r = |r1 - c| is accumulated.  The result is summed over lattice
        images.  The computed value is not written back to ``self.de_nuc_mm``.

        Args:
            mol : QM molecule; defaults to ``self.mol``

        Returns:
            Array with one row of three components per MM atom.
        '''
        if self.de_nuc_mm is not None:
            return self.de_nuc_mm
        cput0 = (logger.process_clock(), logger.perf_counter())
        # The module-level erf comes from cupyx; this host path works on numpy
        # arrays and therefore needs the scipy implementation.
        from scipy.special import erf
        if mol is None:
            mol = self.mol
        mm_mol = self.base.mm_mol
        Ls = mm_mol.get_lattice_Ls()
        nimgs = len(Ls)
        qm_center = np.mean(mol.atom_coords(), axis=0)
        # image-major flattening: row index = image * natm_mm + atom
        all_coords = lib.direct_sum('ix+Lx->Lix',
                mm_mol.atom_coords(), Ls).reshape(-1,3)
        all_charges = np.hstack([mm_mol.atom_charges()] * nimgs)
        all_expnts = np.hstack([np.sqrt(mm_mol.get_zetas())] * nimgs)
        dist2 = all_coords - qm_center
        dist2 = lib.einsum('ix,ix->i', dist2, dist2)
        mask = dist2 <= mm_mol.rcut_hcore**2
        charges = all_charges[mask]
        coords = all_coords[mask]
        expnts = all_expnts[mask]

        g_mm = np.zeros_like(all_coords)
        for i in range(mol.natm):
            q1 = mol.atom_charge(i)
            # displacement from every selected MM image to QM nucleus i
            dr = mol.atom_coord(i) - coords
            r = lib.norm(dr, axis=1)
            g_mm[mask] += q1 * lib.einsum('i,ix,i->ix', charges, dr, erf(expnts*r)/r**3)
            g_mm[mask] -= q1 * lib.einsum('i,ix,i->ix', charges * expnts * 2 / np.sqrt(np.pi),
                                          dr, np.exp(-expnts**2 * r**2)/r**2)
        # fold the lattice images back onto the MM atoms of the reference cell
        g_mm = np.sum(g_mm.reshape(nimgs, -1, 3), axis=0)
        logger.timer(self, 'grad_nuc_mm', *cput0)
        return g_mm

    def _finalize(self):
        '''Fold the Ewald gradient into the result, then run the parent hook.

        ``grad_ewald(with_mm=True)`` yields a (QM part, MM part) pair.  The MM
        part is kept on ``self.de_ewald_mm``, the QM part is added in place to
        ``self.de``, and ``super()._finalize()`` is called last.
        '''
        ewald_qm, ewald_mm = self.grad_ewald(with_mm=True)
        self.de_ewald_mm = ewald_mm
        self.de += ewald_qm
        super()._finalize()
